"""
自定义DecompositionQueryRetriever检索器，把子问题的生成、获得子问题的答案组合起来，然后把组合答案作为上下文，调用大模型回答问题
"""
from pprint import pprint

from langchain.retrievers.multi_query import LineListOutputParser
from langchain_community.vectorstores import Chroma
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate, PromptTemplate, ChatPromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import  Runnable, RunnableLambda

from models import get_ollama_embeddings_client, get_ds_model_client

# 文档设计范例
documents = [
    Document(page_content="""
    番茄炒蛋的食材：
    - 新鲜鸡蛋：3-4个（根据人数调整）\n
    - 番茄：2-3个中等大小\n
    - 盐：适量\n
    - 白糖：一小勺（可选，用于提鲜）\n
    - 食用油：适量\n
    - 葱花：少许（可选，用于增香）\n
    这些是最基本的材料，当然也可以根据个人口味添加其他调料或配料。
    """),
    Document(page_content="""
    番茄炒蛋的步骤：鸡蛋打入碗中，加入少许盐，用筷子或打蛋器充分搅拌均匀；- 番茄洗净后切成小块备用。\n
    3. **炒鸡蛋**：锅内倒入适量食用油加热至温热状态，然后将搅拌好的鸡蛋液缓缓倒入锅中。待鸡蛋凝固时轻轻翻动几下，让其受热均匀直至完全熟透，随后盛出备用。\n
    4. **炒番茄**：在同一锅里留下的底油中放入切好的番茄块，中小火慢慢翻炒至出汁，可根据个人口味加一点点白糖提鲜。\n
    5. **合炒**：当番茄炒至软烂并开始释放大量汤汁时，再把之前炒好的鸡蛋倒回锅里，快速与番茄混合均匀，同时加入适量的盐调味。如果喜欢的话还可以撒上一些葱花增加香气。\n
    6. **完成**：最后检查一下味道是否合适，确认无误后即可关火装盘享用美味的番茄炒蛋啦！
    """),
    Document(page_content="""
    技巧与注意事项：
    1. **选材**：选择新鲜的鸡蛋和成熟的番茄。新鲜的食材是做好这道菜的基础。\n
    2. **打蛋液**：将鸡蛋打入碗中后加入少许盐（根据个人口味调整），然后充分搅拌均匀。这样做可以让蛋更加松软且味道更佳。\n
    3. **处理番茄**：番茄最好先用开水稍微焯一下皮，然后去皮切块。这样可以去除表皮的硬质部分，让番茄更容易入味，并且口感更好。\n
    4. **热锅冷油**：先用中小火把锅烧热，再倒入适量食用油，待油温五成热时下蛋液。这样的做法可以使蛋快速凝固形成漂亮的形状而不易粘锅。\n
    5. **分步烹饪**：通常建议先炒鸡蛋至半熟状态取出备用；接着利用剩下的底油继续翻炒番茄至出汁，最后再将之前炒好的鸡蛋倒回锅里与番茄混合均匀加热即可。\n
    6. **调味品**：除了基本的盐之外，还可以根据喜好添加少量糖来提鲜或者一点酱油增色添香。注意调味料不宜过多以免掩盖了食材本身的味道。\n
    7. **出锅前加葱花**：如果喜欢的话，在即将完成时撒上一些葱花不仅能增加菜品色泽还能增添香气。
    """)
]

DEFAULT_QUERY_PROMPT = PromptTemplate.from_template("""你是一名AI语言模型助理。你的任务是将输入问题分解成3个子问题，通过一个个解决这些子问题从而解决完整的问题。
            子问题需要在矢量数据库中检索相关文档。通过分解用户问题生成子问题，你的目标是帮助用户克服基于距离的相似性搜索的一些局限性。
            请提供这些用换行符分隔的子问题本身，不需要额外内容。
            原始问题: {question}""")

DEFAULT_SUB_QUESTION_PROMPT = PromptTemplate.from_template("""要解决主要问题{question}，需要先解决子问题{sub_question}。
以下是为支持您的推理而提供的参考文档：{documents}。请直接给出当前子问题的答案。不需要额外内容。
""")


class DecompositionQueryRetriever(BaseRetriever):
    # 向量库检索器
    retriever: BaseRetriever
    # 生成子问题链
    llm_chain: Runnable
    # 解决子问题的链
    sub_llm_chain: Runnable

    # 初始化的步骤
    @classmethod
    def from_llm(cls,
                 retriever: BaseRetriever,
                 llm: BaseLanguageModel,
                 prompt: BasePromptTemplate = DEFAULT_QUERY_PROMPT,
                 sub_prompt: BasePromptTemplate = DEFAULT_SUB_QUESTION_PROMPT
                 ) -> "DecompositionQueryRetriever":
        out_parser = LineListOutputParser()
        llm_chain = prompt | llm | out_parser
        sub_llm_chain = sub_prompt | llm
        return cls(
            retriever=retriever,
            llm_chain=llm_chain,
            sub_llm_chain=sub_llm_chain
        )

    # 生成子问题
    def generate_sub_questions(self, question: str) -> list[str]:
        sub_question = self.llm_chain.invoke({"question": question})
        print(f"Generated sub_question: {sub_question}")
        return sub_question

    # 解决子问题
    def retrieve_documents(self, question: str, sub_questions: list[str]) -> list[Document]:
        sub_chain = RunnableLambda(
            lambda sub_question: self.sub_llm_chain.invoke(
                {
                    "question": question,
                    "sub_question": sub_question,
                    "documents": [doc.page_content for doc in self.retriever.invoke(sub_question)]
                }
            )
        )
        response = sub_chain.batch(sub_questions)
        return [
            Document(page_content=sub_q + "\n" + res.content)
            for sub_q, res in zip(sub_questions, response)
        ]

    # 把生成子问题和解决子问题组合起来
    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> list[Document]:
        sub_questions = self.generate_sub_questions(query)
        return self.retrieve_documents(query, sub_questions)

embeddings_client = get_ollama_embeddings_client()
vector_store = Chroma.from_documents(documents, embedding=embeddings_client, collection_name="decomposition")
retriever = vector_store.as_retriever(search_kwargs={"k": 1})
llm = get_ds_model_client()

decomposition_retriever = DecompositionQueryRetriever.from_llm(retriever=retriever, llm=llm)
docs = decomposition_retriever.invoke("新手如何制作番茄炒蛋？")
print("-------------检索到的文档（拆解后）--------------")
pprint(docs)

# 把组合答案作为上下文传递给大模型回答问题
# 创建prompt模板
template = """请根据以下文档回答问题:
### 文档:
{context}
### 问题:
{question}
"""
# 由模板生成prompt
prompt = ChatPromptTemplate.from_template(template)
chain = prompt | llm
print("-------------回答--------------")
question = "新手如何制作番茄炒蛋？"
response = chain.invoke({"context": [doc.page_content for doc in docs], "question": question})
print(response.content)
